{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `register_object`\n",
    "\n",
    "`tvm.register_object(type_key=None)` 实现的关键接口是 `_LIB.TVMObjectTypeKey2Index`，函数的作用是根据给定的 `key` 获取对应的类型索引。（根据注释的说明，具体的实现细节可能在其他地方进行定义。如果你需要使用这个函数，可以在代码中包含该函数的声明，并在需要的地方调用它来获取类型索引。）拿到索引后，调用 `_register_object(tindex, cls)` 在 Python 端完成注册。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`_LIB.TVMObjectTypeKey2Index` 的实现如下查找链路：\n",
    "\n",
    "```c++\n",
    "//   src/runtime/object.cc\n",
    "int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {\n",
    "  API_BEGIN();\n",
    "  out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);\n",
    "  API_END();\n",
    "}\n",
    "```\n",
    "\n",
    "-> `tvm::runtime::ObjectInternal::ObjectTypeKey2Index` 定义如下：\n",
    "```c++\n",
    "static uint32_t ObjectTypeKey2Index(const std::string& type_key) {\n",
    "  return Object::TypeKey2Index(type_key);\n",
    "}\n",
    "```\n",
    "\n",
    "-> `Object::TypeKey2Index` 定义如下：\n",
    "\n",
    "```c++\n",
    "uint32_t Object::TypeKey2Index(const std::string& key) {\n",
    "  return TypeContext::Global()->TypeKey2Index(key);\n",
    "}\n",
    "```\n",
    "->\n",
    "```c++\n",
    "uint32_t TypeKey2Index(const std::string& skey) {\n",
    "    auto it = type_key2index_.find(skey);\n",
    "    ICHECK(it != type_key2index_.end())\n",
    "        << \"Cannot find type \" << skey\n",
    "        << \". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?\";\n",
    "    return it->second;\n",
    "  }\n",
    "```\n",
    "->\n",
    "```c++\n",
    "std::unordered_map<std::string, uint32_t> type_key2index_;\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TVMObjectTypeKey2Index` 接受两个参数：一个指向字符类型的指针 `type_key` 和一个指向无符号整数类型的指针 `out_tindex`。函数的返回类型是 `int`。\n",
    "\n",
    "函数的参数解释如下：\n",
    "- `const char* type_key`：表示类型键的字符串指针。\n",
    "- `unsigned* out_tindex`：指向无符号整数的指针，用于存储转换后的类型索引。\n",
    "\n",
    "函数的返回值解释如下：\n",
    "- 当成功时，返回 `0`。\n",
    "- 当失败时，返回非零值。\n",
    "\n",
    "如果你需要使用这个函数，可以在代码中包含该函数的声明，并在需要的地方调用它来将类型键转换为类型索引。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上述查找过程发现：`TVM_REGISTER_NODE_TYPE` 宏，用于注册 key2index 的绑定。\n",
    "\n",
    "```c++\n",
    "#define TVM_REGISTER_NODE_TYPE(TypeName)                                             \\\n",
    "  TVM_REGISTER_OBJECT_TYPE(TypeName);                                                \\\n",
    "  TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \\\n",
    "      .set_creator([](const std::string&) -> ObjectPtr<Object> {                     \\\n",
    "        return ::tvm::runtime::make_object<TypeName>();                              \\\n",
    "      })\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TVM_REGISTER_NODE_TYPE` 宏用于在 C++ 中注册节点类型。\n",
    "\n",
    "首先，`TVM_REGISTER_NODE_TYPE(TypeName)` 宏定义了函数调用，该函数调用了 `TVM_REGISTER_OBJECT_TYPE(TypeName)` 和 `TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>)` 两个函数。\n",
    "\n",
    "- `TVM_REGISTER_OBJECT_TYPE(TypeName)` 函数用于注册对象类型，将给定的类型名称 `TypeName` 与相应的对象类型关联起来。\n",
    "- `TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>)` 函数用于注册反射虚函数表（vtable），将给定的类型名称 `TypeName` 与相应的反射虚函数表关联起来。这个虚函数表中包含了该类型的反射方法。\n",
    "- 接下来，`.set_creator([](const std::string&) -> ObjectPtr<Object> {...})` 是可选的设置函数，用于指定如何创建该类型的对象实例。在这个例子中，使用了 lambda 表达式作为创建函数，它接受字符串参数，并返回新创建的 `TypeName` 类型的对象实例。\n",
    "\n",
    "综上所述，这段代码的作用是注册节点类型，并提供了创建该类型对象实例的方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `register_object` 示例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在 `src/tvm_ext.cc` 中定义 `test.BaseObj`：\n",
    "\n",
    "```c++\n",
    "#include <string.h>\n",
    "#include <tvm/runtime/object.h>\n",
    "#include <tvm/node/reflection.h>\n",
    "\n",
    "\n",
    "namespace tvm {\n",
    "namespace runtime {\n",
    "class TestNode :public Object {\n",
    "public:\n",
    "    // 对象字段\n",
    "    std::string name;\n",
    "    // 对象属性\n",
    "    static constexpr const uint32_t _type_index = TypeIndex::kDynamic;\n",
    "    static constexpr const char* _type_key = \"app.TestNode\";\n",
    "    // 告诉 TVM 编译器，TestNode 类是 Object 类的子类，\n",
    "    // 并且需要在编译时进行一些特殊的处理。\n",
    "    TVM_DECLARE_BASE_OBJECT_INFO(TestNode, Object);\n",
    "    void VisitAttrs(AttrVisitor* v) {\n",
    "        v->Visit(\"name\", &name);\n",
    "    }\n",
    "};\n",
    "TVM_REGISTER_NODE_TYPE(TestNode); // 注册节点类型\n",
    "}\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在 Python 端调用："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<CDLL './libs/libtvm_ext.so', handle 3b36bf0 at 0x7f1e3c4f68f0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tvm\n",
    "from tvm.runtime import Object\n",
    "from tvm._ffi.base import _LIB\n",
    "import ctypes\n",
    "# _LIB.TVMObjectTypeKey2Index\n",
    "\n",
    "def load_dll(lib_path=\"lib/libtvm_ext.so\"):\n",
    "    \"\"\"加载库，函数将被注册到 TVM\"\"\"\n",
    "    # 作为全局加载，这样全局 extern symbol 对其他 dll 是可见的。\n",
    "    # curr_path = f\"{ROOT}/\"\n",
    "    lib = ctypes.CDLL(lib_path, ctypes.RTLD_GLOBAL)\n",
    "    return lib\n",
    "load_dll(\"./libs/libtvm_ext.so\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "app.TestNode(0x46bcb20)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node = tvm.ir.make_node(\"app.TestNode\", name=\"A\")\n",
    "node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "或者："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm._ffi.register_object(\"app.TestNode\")\n",
    "class TestNode(Object):\n",
    "    def __init__(self, handle):\n",
    "        \"\"\"Initialize the function with handle\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        handle : SymbolHandle\n",
    "            the handle to the underlying C++ Symbol\n",
    "        \"\"\"\n",
    "        super().__init__(handle)\n",
    "        self.handle = handle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果想要改变 `node` 实例的显示内容，也可以在 C++ 端写入：\n",
    "\n",
    "```c++\n",
    "#include <tvm/node/repr_printer.h>\n",
    "TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)\n",
    "    .set_dispatch<TestNode>([](const ObjectRef& ref, ReprPrinter* p) {\n",
    "      auto* op = static_cast<const TestNode*>(ref.get());\n",
    "      p->stream << \"Test(\";\n",
    "      p->stream << \"name=\" << op->name<< \", \";\n",
    "      p->stream << \")\";\n",
    "    });\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Test(name=A, )"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tvm._ffi.register_object(\"app.TestNode\")\n",
    "class TestNode(Object):\n",
    "    def __init__(self, handle):\n",
    "        \"\"\"Initialize the function with handle\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        handle : SymbolHandle\n",
    "            the handle to the underlying C++ Symbol\n",
    "        \"\"\"\n",
    "        super().__init__(handle)\n",
    "        self.handle = handle\n",
    "\n",
    "node = tvm.ir.make_node(\"app.TestNode\", name=\"A\")\n",
    "node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "或者直接在 Python 端改写："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TestNode_A"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tvm._ffi.register_object(\"app.TestNode\")\n",
    "class TestNode(Object):\n",
    "    def __init__(self, handle):\n",
    "        \"\"\"Initialize the function with handle\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        handle : SymbolHandle\n",
    "            the handle to the underlying C++ Symbol\n",
    "        \"\"\"\n",
    "        super().__init__(handle)\n",
    "        self.handle = handle\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"{self.__class__.__name__}_{self.name}\"\n",
    "\n",
    "node = tvm.ir.make_node(\"app.TestNode\", name=\"A\")\n",
    "node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
